import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.stats import gaussian_kde

def FuncHMData(input_file_name, output_file_name, NumGrid, sigma=3, lo=-3, hi=11):
    """Smooth a square grid of function values and write it as (x, y, z) rows.

    Reads a comma-separated ``NumGrid`` x ``NumGrid`` matrix from
    *input_file_name*, blurs it with a Gaussian filter, and writes one CSV
    row per matrix cell to *output_file_name*. The columns are: the x
    coordinate (indexed by matrix row), the y coordinate (indexed by matrix
    column) and the smoothed value. Rows are in row-major order
    (row ``i * NumGrid + j`` holds cell ``[i][j]``) and all values use two
    decimals.

    Args:
        input_file_name: Path of the CSV matrix to read.
        output_file_name: Path of the CSV file to write.
        NumGrid: Number of grid points per axis; must equal both matrix
            dimensions.
        sigma: Standard deviation of the Gaussian filter (larger = smoother;
            0 disables smoothing).
        lo: First coordinate of both axes.
        hi: Last coordinate of both axes (included).

    Raises:
        ValueError: If the matrix read from the file is not
            ``NumGrid`` x ``NumGrid``.
    """
    Z = np.genfromtxt(input_file_name, delimiter=',')
    # The coordinate grid and the matrix must agree in size. Otherwise the
    # flattened rows would pair values with the wrong coordinates.
    if Z.shape != (NumGrid, NumGrid):
        raise ValueError(
            f"expected a {NumGrid}x{NumGrid} matrix in {input_file_name!r}, "
            f"got shape {Z.shape}"
        )
    axis = np.linspace(lo, hi, NumGrid)
    Z_smooth = gaussian_filter(Z, sigma=sigma)
    # indexing='ij' makes xx vary along rows and yy along columns, so cell
    # Z_smooth[i][j] pairs with (axis[i], axis[j]); ravel() is row-major.
    xx, yy = np.meshgrid(axis, axis, indexing='ij')
    output = np.column_stack([xx.ravel(), yy.ravel(), Z_smooth.ravel()])
    np.savetxt(output_file_name, output, delimiter=',', fmt="%.2f")

def TrajHMData(input_file_name, output_file_name, xlim=(-3, 11), ylim=(-3, 11),
               resolution=100, bw_method=0.5):
    """Estimate a normalized 2-D density of trajectory points on a grid.

    Reads points from *input_file_name*, taking column 0 as x and column 1
    as y. It fits a Gaussian KDE to them, evaluates the KDE on a
    ``resolution`` x ``resolution`` grid spanning *xlim* x *ylim* (endpoints
    included), and min-max normalizes the result to [0, 1]. One CSV row per
    grid point (x, y, density) is written to *output_file_name* with two
    decimals. Rows are in row-major order: x is constant over each block of
    ``resolution`` rows while y sweeps its range.

    Args:
        input_file_name: CSV file with at least two columns of coordinates.
        output_file_name: Path of the CSV file to write.
        xlim: ``(xmin, xmax)`` of the evaluation grid.
        ylim: ``(ymin, ymax)`` of the evaluation grid.
        resolution: Number of grid points per axis.
        bw_method: Bandwidth argument forwarded to ``gaussian_kde``.

    Notes:
        If the density is constant over the grid (e.g. it underflows to 0
        because the points lie far outside the grid), all densities are
        written as 0 instead of the NaN a 0/0 normalization would give.
    """
    Traj = np.genfromtxt(input_file_name, delimiter=',')
    x = Traj[:, 0]
    y = Traj[:, 1]
    xmin, xmax = xlim
    ymin, ymax = ylim
    # A complex step in mgrid means "this many points, endpoints included".
    steps = complex(0, resolution)
    xx, yy = np.mgrid[xmin:xmax:steps, ymin:ymax:steps]

    positions = np.vstack([xx.ravel(), yy.ravel()])
    kde = gaussian_kde(np.vstack([x, y]), bw_method=bw_method)
    density = kde(positions).reshape(xx.shape)

    # Min-max normalize to [0, 1], guarding the constant-density case.
    span = density.max() - density.min()
    if span > 0:
        density = (density - density.min()) / span
    else:
        density = np.zeros_like(density)

    output = np.column_stack([xx.ravel(), yy.ravel(), density.ravel()])
    np.savetxt(output_file_name, output, delimiter=',', fmt="%.2f")

def AddRunNumber(input_file_name, output_file_name, run_length=201):
    """Prefix each trajectory row with the index of the run it belongs to.

    Reads a CSV from *input_file_name* and writes a 3-column CSV to
    *output_file_name*: ``row_index // run_length``, then the first two
    input columns. Values use two decimals. The input is presumably a
    stack of runs, each exactly ``run_length`` consecutive rows long; the
    last run may be shorter.

    Args:
        input_file_name: CSV with at least two columns.
        output_file_name: Path of the CSV file to write.
        run_length: Number of consecutive rows per run.
    """
    Z = np.genfromtxt(input_file_name, delimiter=',')
    dim = Z.shape[0]
    output = np.zeros(shape=(dim, 3))
    output[:, 0] = np.arange(dim) // run_length
    output[:, 1:3] = Z[:, :2]
    np.savetxt(output_file_name, output, delimiter=',', fmt="%.2f")

def RunFilter(input_file_name, output_file_name, run, run_length=201):
    """Extract one run from a run-indexed trajectory file and add step deltas.

    The input CSV has the run number in column 0, followed by x and y in
    columns 1 and 2 (the layout ``AddRunNumber`` writes). The first
    ``run_length`` rows of run *run* are written to *output_file_name* as a
    6-column CSV with two decimals. The columns are:
    run, x, y, dx, dy, fade. dx/dy are the differences from the previous
    row (0 on the first row). fade decreases linearly from 1 to 0 across
    the run.

    Args:
        input_file_name: Run-indexed CSV (at least three columns).
        output_file_name: Path of the CSV file to write.
        run: Run number to extract; compared exactly against column 0.
        run_length: Number of rows in a run (at least 2).

    Raises:
        ValueError: If *run_length* < 2, the run is absent, or the run does
            not have *run_length* consecutive rows.
    """
    if run_length < 2:
        raise ValueError(f"run_length must be at least 2, got {run_length}")
    Z = np.genfromtxt(input_file_name, delimiter=',')
    matches = np.flatnonzero(Z[:, 0] == run)
    if matches.size == 0:
        raise ValueError(f"run {run} not found in {input_file_name!r}")
    start = matches[0]
    segment = Z[start:start + run_length]
    # Reject truncated runs instead of raising IndexError or silently
    # reading rows that belong to the following run.
    if len(segment) < run_length or np.any(segment[:, 0] != run):
        raise ValueError(
            f"run {run} in {input_file_name!r} has fewer than "
            f"{run_length} consecutive rows"
        )

    output = np.zeros(shape=(run_length, 6))
    output[:, 0] = run
    output[:, 1:3] = segment[:, 1:3]
    # Step-to-step displacement; row 0 keeps dx = dy = 0.
    output[1:, 3:5] = np.diff(segment[:, 1:3], axis=0)
    # Linear fade 1 -> 0, same arithmetic as 1 - 1/200 * j for 201 rows.
    output[:, 5] = 1 - 1 / (run_length - 1) * np.arange(run_length)
    np.savetxt(output_file_name, output, delimiter=',', fmt="%.2f")
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, so importing this
    # module to reuse the functions does not touch hard-coded files.
    FuncHMData('TrainingLossData.csv', 'TrainingLoss_transformed.csv', 100)
    # Other stages; uncomment as needed. RunFilter reads the run-indexed
    # file that AddRunNumber writes.
    #TrajHMData('NoisyGDTrajHMData.csv', 'NoisyGDTraj_transformed.csv')
    #AddRunNumber('NoisyGDTrajHMData.csv', 'NoisyGDTrajHM_indexed.csv')
    #RunFilter('NoisyGDTrajHM_indexed.csv', 'NoisyGD_Filtered_Run_42.csv', 42.00)